<?php

namespace DataCube\DataCubeAggregation\AI_Toolkit\Classification;

use DataCube\DataCubeAggregation\AI_Toolkit\Interfaces\ProbabilityEstimator;
use DataCube\DataCubeAggregation\AI_Toolkit\Interfaces\TrainerInterface;
use DataCube\DataCubeAggregation\Exception\CustomException;
use DataCube\DataCubeAggregation\Exception\CustomInvalidArgumentException;
use Symfony\Component\OptionsResolver\OptionsResolver;
use Rubix\ML\Classifiers\KDNeighbors as RubixKDNeighbors;
use Rubix\ML\Graph\Trees\BallTree;
use Rubix\ML\Kernels\Distance\Minkowski;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
/**
 * K-d neighbors classifier.
 *
 * Thin adapter that exposes Rubix ML's KDNeighbors estimator through the
 * toolkit's TrainerInterface and ProbabilityEstimator contracts.
 *
 * Supported constructor options (resolved with Symfony OptionsResolver):
 *  - k                    passed to Rubix as the number of neighbours (default 10)
 *  - weighted             passed to Rubix as the `weighted` flag (default false)
 *  - spatial              spatial tree handed to Rubix
 *                         (default: BallTree(40, Minkowski))
 *  - probabilityEstimates enables predictProbability() (default true, must be bool)
 */
class KDNeighbors extends BaseClassifier implements TrainerInterface, ProbabilityEstimator
{
    /**
     * @param array $options option overrides; unknown keys are rejected by the resolver
     */
    public function __construct(array $options = [])
    {
        $resolver = new OptionsResolver();
        $this->configureOptions($resolver);

        $this->options = $resolver->resolve($options);
        $this->classifier = new RubixKDNeighbors(
            $this->options['k'],
            $this->options['weighted'],
            $this->options['spatial']
        );
    }

    /**
     * Declare the supported options and their defaults.
     *
     * @param OptionsResolver $resolver
     * @return void
     */
    public function configureOptions(OptionsResolver $resolver): void
    {
        $resolver->setDefaults([
            'k' => 10,
            'weighted' => false,
            'spatial' => new BallTree(40, new Minkowski()),
            // Rubix's KDNeighbors natively supports proba(), so probability
            // estimates are on unless the caller explicitly disables them.
            'probabilityEstimates' => true,
        ]);
        $resolver->setAllowedTypes('probabilityEstimates', 'bool');
    }

    /**
     * Fit the underlying Rubix model.
     *
     * @param array $samples one feature vector (array) per sample
     * @param array $targets the label of each sample, in the same order as $samples
     * @return void
     * @throws CustomException when Rubix rejects the dataset or training fails
     */
    public function train(array $samples, array $targets): void
    {
        try {
            $this->classifier->train(new Labeled($samples, $targets));
        } catch (\Exception $e) {
            // Surface the failure now rather than leaving an untrained model
            // that would only fail later, inside predict().
            throw new CustomException('Could not train');
        }
    }

    /**
     * Predict class labels.
     *
     * Accepts either a single feature vector (e.g. [1.0, 2.0]) or a list of
     * feature vectors (e.g. [[1.0, 2.0], [3.0, 4.0]]). In both cases the array
     * returned by Rubix's predict() is passed through unchanged, so a single
     * sample yields a one-element array.
     *
     * @param array $samples a single sample or a list of samples
     * @return array the predictions returned by Rubix
     * @throws CustomException when Rubix rejects the input or prediction fails
     */
    public function predict(array $samples)
    {
        try {
            return $this->classifier->predict($this->toUnlabeledDataset($samples));
        } catch (\Exception $e) {
            throw new CustomException('Could not predict');
        }
    }

    /**
     * Estimate class membership probabilities.
     *
     * @param array $sample a single feature vector, or a list of feature vectors
     * @return array for a single sample, the probability entry Rubix returns for
     *               it; for a list of samples, the full list Rubix returns
     * @throws \InvalidArgumentException when the probabilityEstimates option is false
     * @throws CustomException when Rubix fails to compute the probabilities
     */
    public function predictProbability(array $sample)
    {
        if (!$this->options['probabilityEstimates']) {
            throw new \InvalidArgumentException('To predict probabilities you must build a classifier with $probabilityEstimates set to true.');
        }

        try {
            // Rubix exposes probabilities through proba(), which takes a Dataset
            // rather than a raw array.
            $probabilities = $this->classifier->proba($this->toUnlabeledDataset($sample));
        } catch (\Exception $e) {
            throw new CustomException('Could not predict probabilities');
        }

        return $this->isSampleBatch($sample) ? $probabilities : $probabilities[0];
    }

    /**
     * Return the spatial tree held by the underlying Rubix model.
     */
    public function tree()
    {
        return $this->classifier->tree();
    }

    /**
     * Whether $samples is a list of feature vectors rather than a single vector.
     * Only the first element is inspected; mixed input is left for Rubix to reject.
     */
    private function isSampleBatch(array $samples): bool
    {
        return $samples !== [] && is_array(reset($samples));
    }

    /**
     * Wrap a single sample, or a list of samples, as a Rubix Unlabeled dataset.
     */
    private function toUnlabeledDataset(array $samples): Unlabeled
    {
        return new Unlabeled($this->isSampleBatch($samples) ? $samples : [$samples]);
    }
}